{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "PyTorch/TPU MNIST Training",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0WzFpOKDmURO",
        "colab_type": "text"
      },
      "source": [
        "## PyTorch/TPU MNIST Demo\n",
        "\n",
        "This colab example corresponds to the implementation under [test_train_mp_mnist.py](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xOp9jBEumdvC",
        "colab_type": "text"
      },
      "source": [
        "<h3>  &nbsp;&nbsp;Use Colab Cloud TPU&nbsp;&nbsp; <a href=\"https://cloud.google.com/tpu/\"><img valign=\"middle\" src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"></a></h3>\n",
        "\n",
        "* On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n",
        "* The cell below makes sure you have access to a TPU on Colab.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Hx4YVNHametU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "\n",
        "# Fail fast when no TPU runtime is attached. Use .get() so that a missing\n",
        "# variable triggers the assertion message below instead of a bare KeyError.\n",
        "assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YofXQrnxmf5r",
        "colab_type": "text"
      },
      "source": [
        "### [RUNME] Install Colab TPU compatible PyTorch/TPU wheels and dependencies\n",
        "This step may take up to 2 minutes."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sPJVqAKyml5W",
        "colab_type": "code",
        "cellView": "both",
        "colab": {}
      },
      "source": [
        "# PyTorch/XLA build to install; the #@param annotation exposes it as a Colab form field.\n",
        "VERSION = \"20200516\"  #@param [\"1.5\" , \"20200516\", \"nightly\"]\n",
        "# Download the PyTorch/XLA environment setup script, then run it for the selected version.\n",
        "!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
        "!python pytorch-xla-env-setup.py --version $VERSION"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OHpYeQrOLcmZ"
      },
      "source": [
        "### Define Parameters and Helpers, and start Training\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LNLegt4jLkRD",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Result Visualization Helper\n",
        "import math\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "# Grid layout of the results figure: M rows by N columns, one sample image per cell.\n",
        "M, N = 4, 6\n",
        "# File that plot_results saves the figure to; the visualization cell reads it back.\n",
        "RESULT_IMG_PATH = '/tmp/test_result.png'\n",
        "\n",
        "def plot_results(images, labels, preds):\n",
        "  \"\"\"Draw a grid of sample digits with their labels and save it to disk.\n",
        "\n",
        "  At most the first M*N samples are drawn, one per subplot. A correct\n",
        "  prediction is titled with a blue check mark; an incorrect one is titled\n",
        "  'X <label>/<prediction>' in red. The figure is written to RESULT_IMG_PATH.\n",
        "\n",
        "  Args:\n",
        "    images: batch of normalized images indexed along the first dimension;\n",
        "      each entry is squeezed before display (e.g. [1,Y,X] -> [Y,X]).\n",
        "    labels: ground-truth label per image; each entry must support .item().\n",
        "    preds: predicted label per image; each entry must support .item().\n",
        "  \"\"\"\n",
        "  # Imported locally so this helper does not depend on a later cell having\n",
        "  # already imported torchvision.\n",
        "  from torchvision import transforms\n",
        "\n",
        "  images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N]\n",
        "  # Inverse of Normalize((0.1307,), (0.3081,)), which is applied when the\n",
        "  # dataset is loaded, so the digits are shown in their original value range.\n",
        "  inv_norm = transforms.Normalize((-0.1307/0.3081,), (1/0.3081,))\n",
        "\n",
        "  num_images = images.shape[0]\n",
        "  fig, _ = plt.subplots(M, N, figsize=(11, 9))\n",
        "  fig.suptitle('Correct / Predicted Labels (Red text for incorrect ones)')\n",
        "\n",
        "  for i, ax in enumerate(fig.axes):\n",
        "    ax.axis('off')\n",
        "    if i >= num_images:\n",
        "      # Fewer samples than grid cells: leave the remaining axes blank.\n",
        "      continue\n",
        "    img, label, prediction = images[i], labels[i], preds[i]\n",
        "    img = inv_norm(img)\n",
        "    img = img.squeeze() # [1,Y,X] -> [Y,X]\n",
        "    label, prediction = label.item(), prediction.item()\n",
        "    if label == prediction:\n",
        "      ax.set_title(u'\\u2713', color='blue', fontsize=22)\n",
        "    else:\n",
        "      ax.set_title(\n",
        "          'X {}/{}'.format(label, prediction), color='red')\n",
        "    ax.imshow(img)\n",
        "  plt.savefig(RESULT_IMG_PATH, transparent=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kNh-oEmHmorI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Define Parameters\n",
        "FLAGS = {\n",
        "    # Directory the MNIST dataset is downloaded to and loaded from.\n",
        "    'datadir': \"/tmp/mnist\",\n",
        "    # Per-process batch size for both the train and test loaders.\n",
        "    'batch_size': 128,\n",
        "    # Number of DataLoader worker processes.\n",
        "    'num_workers': 4,\n",
        "    # Base learning rate; multiplied by the world size before use.\n",
        "    'learning_rate': 0.01,\n",
        "    # SGD momentum.\n",
        "    'momentum': 0.5,\n",
        "    # Number of train + eval epochs.\n",
        "    'num_epochs': 10,\n",
        "    # Number of processes handed to xmp.spawn.\n",
        "    'num_cores': 8,\n",
        "    # Print a training log line every this many steps.\n",
        "    'log_steps': 20,\n",
        "    # When True, print the XLA metrics report after each epoch.\n",
        "    'metrics_debug': False,\n",
        "}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pTmxZL5ymp8P",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "import os\n",
        "import time\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch_xla\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.debug.metrics as met\n",
        "import torch_xla.distributed.parallel_loader as pl\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "import torch_xla.utils.utils as xu\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "\n",
        "# Shared serial executor; train_mnist() runs the dataset download through it\n",
        "# so that multiple processes do not download the same data.\n",
        "SERIAL_EXEC = xmp.MpSerialExecutor()\n",
        "\n",
        "class MNIST(nn.Module):\n",
        "  \"\"\"Small convolutional classifier returning log-probabilities over 10 outputs.\n",
        "\n",
        "  Two conv -> max-pool -> ReLU -> batch-norm stages feed two fully\n",
        "  connected layers, followed by a log-softmax.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    # Stage 1: 1 -> 10 channels.\n",
        "    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
        "    self.bn1 = nn.BatchNorm2d(10)\n",
        "    # Stage 2: 10 -> 20 channels.\n",
        "    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
        "    self.bn2 = nn.BatchNorm2d(20)\n",
        "    # Classifier head; fc1 takes 320 flattened features per sample.\n",
        "    self.fc1 = nn.Linear(320, 50)\n",
        "    self.fc2 = nn.Linear(50, 10)\n",
        "\n",
        "  def forward(self, x):\n",
        "    \"\"\"Return log_softmax scores (over dim 1) for the input batch x.\"\"\"\n",
        "    stage1 = self.bn1(F.relu(F.max_pool2d(self.conv1(x), 2)))\n",
        "    stage2 = self.bn2(F.relu(F.max_pool2d(self.conv2(stage1), 2)))\n",
        "    hidden = F.relu(self.fc1(stage2.flatten(1)))\n",
        "    logits = self.fc2(hidden)\n",
        "    return F.log_softmax(logits, dim=1)\n",
        "\n",
        "# Only instantiate model weights once in memory.\n",
        "# The model is built here at module level and wrapped; train_mnist() then\n",
        "# obtains its device copy via WRAPPED_MODEL.to(device).\n",
        "WRAPPED_MODEL = xmp.MpModelWrapper(MNIST())\n",
        "\n",
        "def train_mnist():\n",
        "  \"\"\"Train and evaluate the MNIST model on this process's XLA device.\n",
        "\n",
        "  Hyperparameters are read from the module-level FLAGS dict. Each epoch runs\n",
        "  one pass over this replica's shard of the training set, followed by an\n",
        "  evaluation pass over the test set.\n",
        "\n",
        "  Returns:\n",
        "    A tuple (accuracy, data, pred, target) from the last evaluation pass:\n",
        "    accuracy in percent, plus the inputs, predictions and targets of the\n",
        "    final test batch. If no epoch runs, this is (0.0, None, None, None).\n",
        "  \"\"\"\n",
        "  torch.manual_seed(1)\n",
        "\n",
        "  def get_dataset():\n",
        "    \"\"\"Load (downloading if needed) the normalized MNIST train and test sets.\"\"\"\n",
        "    norm = transforms.Normalize((0.1307,), (0.3081,))\n",
        "    train_dataset = datasets.MNIST(\n",
        "        FLAGS['datadir'],\n",
        "        train=True,\n",
        "        download=True,\n",
        "        transform=transforms.Compose(\n",
        "            [transforms.ToTensor(), norm]))\n",
        "    test_dataset = datasets.MNIST(\n",
        "        FLAGS['datadir'],\n",
        "        train=False,\n",
        "        download=True,\n",
        "        transform=transforms.Compose(\n",
        "            [transforms.ToTensor(), norm]))\n",
        "\n",
        "    return train_dataset, test_dataset\n",
        "\n",
        "  # Using the serial executor prevents multiple processes from\n",
        "  # downloading the same data at the same time.\n",
        "  train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)\n",
        "\n",
        "  # Each replica trains on its own shard of the training set.\n",
        "  train_sampler = torch.utils.data.distributed.DistributedSampler(\n",
        "    train_dataset,\n",
        "    num_replicas=xm.xrt_world_size(),\n",
        "    rank=xm.get_ordinal(),\n",
        "    shuffle=True)\n",
        "  train_loader = torch.utils.data.DataLoader(\n",
        "      train_dataset,\n",
        "      batch_size=FLAGS['batch_size'],\n",
        "      sampler=train_sampler,\n",
        "      num_workers=FLAGS['num_workers'],\n",
        "      drop_last=True)\n",
        "  # NOTE(review): drop_last=True also discards the final partial test batch, so\n",
        "  # accuracy is computed on slightly fewer samples -- presumably to keep batch\n",
        "  # shapes static on XLA; confirm this is intended.\n",
        "  test_loader = torch.utils.data.DataLoader(\n",
        "      test_dataset,\n",
        "      batch_size=FLAGS['batch_size'],\n",
        "      shuffle=False,\n",
        "      num_workers=FLAGS['num_workers'],\n",
        "      drop_last=True)\n",
        "\n",
        "  # Scale learning rate to world size\n",
        "  lr = FLAGS['learning_rate'] * xm.xrt_world_size()\n",
        "\n",
        "  # Get loss function, optimizer, and model\n",
        "  device = xm.xla_device()\n",
        "  model = WRAPPED_MODEL.to(device)\n",
        "  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])\n",
        "  loss_fn = nn.NLLLoss()\n",
        "\n",
        "  def train_loop_fn(loader):\n",
        "    \"\"\"Run one training pass over loader, logging every FLAGS['log_steps'] steps.\"\"\"\n",
        "    tracker = xm.RateTracker()\n",
        "    model.train()\n",
        "    for step, (data, target) in enumerate(loader):\n",
        "      optimizer.zero_grad()\n",
        "      output = model(data)\n",
        "      loss = loss_fn(output, target)\n",
        "      loss.backward()\n",
        "      xm.optimizer_step(optimizer)\n",
        "      tracker.add(FLAGS['batch_size'])\n",
        "      if step % FLAGS['log_steps'] == 0:\n",
        "        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(\n",
        "            xm.get_ordinal(), step, loss.item(), tracker.rate(),\n",
        "            tracker.global_rate(), time.asctime()), flush=True)\n",
        "\n",
        "  def test_loop_fn(loader):\n",
        "    \"\"\"Evaluate on loader; return (accuracy %, last data, last pred, last target).\"\"\"\n",
        "    total_samples = 0\n",
        "    correct = 0\n",
        "    model.eval()\n",
        "    data, pred, target = None, None, None\n",
        "    # Inference only: no gradients are needed, so skip autograd bookkeeping.\n",
        "    with torch.no_grad():\n",
        "      for data, target in loader:\n",
        "        output = model(data)\n",
        "        # Index of the highest log-probability, kept as a [batch, 1] column.\n",
        "        pred = output.max(1, keepdim=True)[1]\n",
        "        correct += pred.eq(target.view_as(pred)).sum().item()\n",
        "        total_samples += data.size()[0]\n",
        "\n",
        "    accuracy = 100.0 * correct / total_samples\n",
        "    print('[xla:{}] Accuracy={:.2f}%'.format(\n",
        "        xm.get_ordinal(), accuracy), flush=True)\n",
        "    return accuracy, data, pred, target\n",
        "\n",
        "  # Train and eval loops\n",
        "  accuracy = 0.0\n",
        "  data, pred, target = None, None, None\n",
        "  for epoch in range(1, FLAGS['num_epochs'] + 1):\n",
        "    # Give the sampler the epoch number so its shuffle differs per epoch;\n",
        "    # without this call DistributedSampler reuses the same ordering every epoch.\n",
        "    train_sampler.set_epoch(epoch)\n",
        "    para_loader = pl.ParallelLoader(train_loader, [device])\n",
        "    train_loop_fn(para_loader.per_device_loader(device))\n",
        "    xm.master_print(\"Finished training epoch {}\".format(epoch))\n",
        "\n",
        "    para_loader = pl.ParallelLoader(test_loader, [device])\n",
        "    accuracy, data, pred, target  = test_loop_fn(para_loader.per_device_loader(device))\n",
        "    if FLAGS['metrics_debug']:\n",
        "      xm.master_print(met.metrics_report(), flush=True)\n",
        "\n",
        "  return accuracy, data, pred, target"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Afwo4H7kSd8P",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Start training processes\n",
        "def _mp_fn(rank, flags):\n",
        "  \"\"\"Per-process entry point: train, then plot results from the rank-0 process.\n",
        "\n",
        "  Args:\n",
        "    rank: index of this process; only rank 0 renders the results figure.\n",
        "    flags: hyperparameter dict installed as the module-level FLAGS.\n",
        "  \"\"\"\n",
        "  global FLAGS\n",
        "  FLAGS = flags\n",
        "  torch.set_default_tensor_type('torch.FloatTensor')\n",
        "  accuracy, data, pred, target = train_mnist()\n",
        "  if rank == 0:\n",
        "    # Retrieve tensors that are on TPU core 0 and plot.\n",
        "    # plot_results takes (images, labels, preds): ground truth goes before the\n",
        "    # predictions so the 'Correct / Predicted' titles are not swapped.\n",
        "    plot_results(data.cpu(), target.cpu(), pred.cpu())\n",
        "\n",
        "# Spawn FLAGS['num_cores'] forked worker processes; each is expected to\n",
        "# call _mp_fn(rank, FLAGS).\n",
        "xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],\n",
        "          start_method='fork')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MznTE72_mthI",
        "colab_type": "text"
      },
      "source": [
        "## Visualize Predictions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "X9VAwyUnI7Sb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from google.colab.patches import cv2_imshow\n",
        "import cv2\n",
        "img = cv2.imread(RESULT_IMG_PATH, cv2.IMREAD_UNCHANGED)\n",
        "# cv2.imread returns None (it does not raise) when the file is missing or\n",
        "# unreadable, e.g. if the training cell above did not finish.\n",
        "if img is None:\n",
        "  raise FileNotFoundError(\n",
        "      'Could not read {}; run the training cell first.'.format(RESULT_IMG_PATH))\n",
        "cv2_imshow(img)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
